import streamlit as st
# from groq import Groq
from openai import Client
import json
import os
from io import BytesIO
from markdown import markdown
from weasyprint import HTML, CSS
from dotenv import load_dotenv
from half_json.core import JSONFixer

# Pull variables from a local .env file into the process environment.
load_dotenv()

# Model name falls back to a default; the API key has no fallback (None when unset).
MODEL_NAME = os.getenv("OPENAI_MODEL_NAME", "gpt-4-turbo")
API_KEY = os.getenv("OPENAI_API_KEY")

json_fixer = JSONFixer()

# Seed the session's api_key entry only when it is not already present.
if 'api_key' not in st.session_state:
    st.session_state.api_key = API_KEY

# Build a client up front only when a key was found in the environment.
if API_KEY and 'client' not in st.session_state:
    st.session_state.client = Client(api_key=API_KEY)

class GenerationStatistics:
    """Token counts and timing figures for one or more model calls.

    Instances can be merged with :meth:`add` and rendered as a markdown
    summary table with ``str()``.
    """

    def __init__(
        self,
        input_time=0,
        output_time=0,
        input_tokens=0,
        output_tokens=0,
        total_time=0,
        model_name=MODEL_NAME,
    ):
        self.model_name = model_name
        self.input_tokens = input_tokens
        self.output_tokens = output_tokens
        self.input_time = input_time
        self.output_time = output_time
        # Sum of queue, prompt (input), and completion (output) times
        self.total_time = total_time

    def get_input_speed(self):
        """Return input tokens per second, or 0 when no input time was recorded."""
        return self.input_tokens / self.input_time if self.input_time != 0 else 0

    def get_output_speed(self):
        """Return output tokens per second, or 0 when no output time was recorded."""
        return self.output_tokens / self.output_time if self.output_time != 0 else 0

    def add(self, other):
        """Accumulate the counters of another GenerationStatistics into this one.

        Raises:
            TypeError: if *other* is not a GenerationStatistics instance.
        """
        if not isinstance(other, GenerationStatistics):
            raise TypeError("Can only add GenerationStatistics objects")

        # model_name is deliberately left untouched; only numeric counters merge.
        for field in ("input_time", "output_time", "input_tokens", "output_tokens", "total_time"):
            setattr(self, field, getattr(self, field) + getattr(other, field))

    def __str__(self):
        """Render a markdown headline plus a speed / tokens / time table."""
        in_speed = self.get_input_speed()
        out_speed = self.get_output_speed()
        total_tokens = self.input_tokens + self.output_tokens
        total_speed = total_tokens / self.total_time if self.total_time != 0 else 0

        rows = [
            f"\n## {out_speed:.2f} T/s ⚡\nRound trip time: {self.total_time:.2f}s  Model: {self.model_name}\n",
            "| Metric          | Input          | Output          | Total          |",
            "|-----------------|----------------|-----------------|----------------|",
            f"| Speed (T/s)     | {in_speed:.2f}            | {out_speed:.2f}            | {total_speed:.2f}            |",
            f"| Tokens          | {self.input_tokens}            | {self.output_tokens}            | {total_tokens}            |",
            f"| Inference Time (s) | {self.input_time:.2f}            | {self.output_time:.2f}            | {self.total_time:.2f}            |",
        ]
        return "\n".join(rows)

class Book:
    """A book being generated: a nested outline plus the text written per section.

    ``structure`` is a (possibly nested) mapping of section title to either a
    description or another mapping of subsections. Generated text is stored
    per title in ``contents`` and rendered through one Streamlit placeholder
    per title held in ``placeholders``.
    """

    def __init__(self, book_title, structure):
        """Create empty content slots and render the title and table of contents.

        Args:
            book_title: Title shown as the top-level heading.
            structure: Nested mapping describing the sections of the book.
        """
        self.book_title = book_title
        self.structure = structure
        # Flatten once and reuse the same title list for both mappings.
        section_titles = self.flatten_structure(structure)
        self.contents = {title: "" for title in section_titles}
        self.placeholders = {title: st.empty() for title in section_titles}
        st.markdown(f"# {self.book_title}")
        st.markdown("## 生成以下内容:")
        toc_columns = st.columns(4)
        self.display_toc(self.structure, toc_columns)
        st.markdown("---")

    def flatten_structure(self, structure):
        """Return every title in *structure*, parents before their children."""
        sections = []
        for title, content in structure.items():
            sections.append(title)
            if isinstance(content, dict):
                sections.extend(self.flatten_structure(content))
        return sections

    def update_content(self, title, new_content):
        """Append *new_content* to the section *title* and refresh its placeholder.

        Non-string chunks (for example ``None``) are skipped. An unknown
        *title* raises ``KeyError``.
        """
        # Explicit guard instead of catching TypeError: a blanket except would
        # also hide unrelated TypeErrors raised while rendering.
        if not isinstance(new_content, str):
            return
        self.contents[title] += new_content
        self.display_content(title)

    def display_content(self, title):
        """Render the section under a level-2 heading if it has non-blank text."""
        if self.contents[title].strip():
            self.placeholders[title].markdown(f"## {title}\n{self.contents[title]}")

    def display_structure(self, structure=None, level=1):
        """Render every section that already has text, nesting headings by depth."""
        if structure is None:
            structure = self.structure

        for title, content in structure.items():
            if self.contents[title].strip():  # Only display title if there is content
                st.markdown(f"{'#' * level} {title}")
                self.placeholders[title].markdown(self.contents[title])
            if isinstance(content, dict):
                self.display_structure(content, level + 1)

    def display_toc(self, structure, columns, level=1, col_index=0):
        """Write the outline as bullet items spread round-robin over *columns*.

        Returns the next column index so recursive calls continue the rotation.
        """
        for title, content in structure.items():
            with columns[col_index % len(columns)]:
                st.markdown(f"{' ' * (level-1) * 2}- {title}")
            col_index += 1
            if isinstance(content, dict):
                col_index = self.display_toc(content, columns, level + 1, col_index)
        return col_index

    def get_markdown_content(self, structure=None, level=1):
        """
        Returns the markdown styled pure string with the contents.

        The book title heading is emitted only at ``level == 1``; sections
        without text are omitted, but their subsections are still visited.
        """
        if structure is None:
            structure = self.structure

        parts = [f"# {self.book_title}\n\n"] if level == 1 else []

        for title, content in structure.items():
            body = self.contents[title]
            if body.strip():  # Only include title if there is content
                parts.append(f"{'#' * level} {title}\n{body}\n\n")
            if isinstance(content, dict):
                parts.append(self.get_markdown_content(content, level + 1))
        return "".join(parts)

def create_markdown_file(content: str) -> BytesIO:
    """Return *content* as an in-memory UTF-8 byte stream positioned at the start."""
    encoded = content.encode('utf-8')
    return BytesIO(encoded)

def create_pdf_file(content: str) -> BytesIO:
    """
    Convert Markdown content into a styled PDF file.

    Args:
        content (str): Text in Markdown format.

    Returns:
        BytesIO: In-memory PDF data, positioned at the start of the stream.
    """

    # Convert the Markdown to HTML
    html_content = markdown(content, extensions=['extra', 'codehilite'])

    # CSS styles, including font loading. This is a plain (non f-) string, so
    # braces must be single: doubled braces would be passed through literally
    # and produce invalid CSS.
    css_style = """
    @import url('https://static.zeoseven.com/cn/48/main/result.css');
    @page {
        margin: 2cm;
    }
    @font-face {
        font-family: "Source Han Serif VF";
        font-weight: normal;
    }
    body {
        font-family: "SimSun", "宋体", sans-serif;
        line-height: 1.6;
        font-size: 12pt;
    }
    h1, h2, h3, h4, h5, h6 {
        color: #333366;
        margin-top: 1em;
        margin-bottom: 0.5em;
    }
    p {
        margin-bottom: 0.5em;
    }
    code {
        background-color: #f4f4f4;
        padding: 2px 4px;
        border-radius: 4px;
        font-family: monospace;
        font-size: 0.9em;
    }
    pre {
        background-color: #f4f4f4;
        padding: 1em;
        border-radius: 4px;
        white-space: pre-wrap;
        word-wrap: break-word;
    }
    blockquote {
        border-left: 4px solid #ccc;
        padding-left: 1em;
        margin-left: 0;
        font-style: italic;
    }
    table {
        border-collapse: collapse;
        width: 100%;
        margin-bottom: 1em;
    }
    th, td {
        border: 1px solid #ddd;
        padding: 8px;
        text-align: left;
    }
    th {
        background-color: #f2f2f2;
    }
    """

    # Build the HTML document; css_style is inserted as a value, so its single
    # braces are not interpreted by the f-string.
    styled_html = f"""
    <html>
        <head>
            <meta charset='UTF-8'>
            <style type="text/css">{css_style}</style>
        </head>
        <body>
            {html_content}
        </body>
    </html>
    """
    # Create the CSS object
    css = CSS(string=css_style)

    # Use WeasyPrint to convert the HTML to PDF
    # NOTE(review): `encoding` is passed to write_pdf() as in the original code;
    # confirm the installed WeasyPrint release accepts it there.
    pdf_buffer = BytesIO()
    HTML(string=styled_html).write_pdf(pdf_buffer, stylesheets=[css], encoding='utf-8')

    pdf_buffer.seek(0)
    return pdf_buffer

def generate_book_title(prompt: str):
    """
    Ask the chat model for a single book title on *prompt*.

    Uses the client stored in ``st.session_state`` and returns the first
    choice's message content with surrounding whitespace removed.
    """
    system_message = {
        "role": "system",
        "content": "Generate suitable book titles for the provided topics. There is only one generated book title! Don't give any explanation or add any symbols, just write the title of the book. The requirement for this title is that it must be between 7 and 25 words long, and it must be attractive enough!",
    }
    user_message = {
        "role": "user",
        "content": f"Generate a book title for the following topic. There is only one generated book title! Don't give any explanation or add any symbols, just write the title of the book. The requirement for this title is that it must be at least 7 words and 25 words long, and it must be attractive enough:\n\n{prompt}",
    }

    chat_api = st.session_state.client.chat.completions
    response = chat_api.create(
        model=MODEL_NAME,
        messages=[system_message, user_message],
        temperature=0.7,
        max_tokens=100,
        top_p=1,
        stream=False,
        stop=None,
    )

    title = response.choices[0].message.content
    return title.strip()

def generate_book_structure(prompt: str):
    """
    Request a JSON outline for a book on *prompt*.

    Returns a ``(GenerationStatistics, str)`` pair: usage figures for the call
    and the raw message content of the first choice. Timing fields are read
    with a default of 0 when the usage object does not provide them.
    """
    # One worked example (short Java book) precedes the real request.
    conversation = [
        {
            "role": "system",
            "content": "Write in JSON format: {\"Title of section goes here\":\"Description of section goes here\",\"Title of section goes here\":{\"Title of section goes here\":\"Description of section goes here\",\"Title of section goes here\":\"Description of section goes here\",\"Title of section goes here\":\"Description of section goes here\"}}",
        },
        {
            "role": "user",
            "content": "Write a comprehensive structure in JSON format, omiting introduction and conclusion sections (forward, author's note, summary), for a short (3 page) book on the following subject:\n\n<subject>Data Structures and Algorithms in Java</subject>",
        },
        {
            "role": "assistant",
            "content": '{ "Data Structures and Algorithms in Java" : {"What is Data Structure?" : "Description of Data Structures and Their Importance", "Why Java for Data Structures?" : "Reasons for choosing Java for data structures and algorithms" , "Overview of the Book" : { "Java Basics":"Review of Java Syntax and Basics" } } }',
        },
        {
            "role": "user",
            "content": f"Write a comprehensive structure in JSON format, Do not use ``` to wrap, omiting introduction and conclusion sections (forward, author's note, summary), for a long (>300 page) book on the following subject:\n\n<subject>{prompt}</subject>",
        },
    ]

    chat_api = st.session_state.client.chat.completions
    response = chat_api.create(
        model=MODEL_NAME,
        messages=conversation,
        temperature=0.3,
        max_tokens=8000,
        top_p=1,
        stream=False,
        response_format={"type": "json_object"},
        stop=None,
    )

    token_usage = response.usage
    stats = GenerationStatistics(
        input_time=getattr(token_usage, "prompt_time", 0),
        output_time=getattr(token_usage, "completion_time", 0),
        input_tokens=token_usage.prompt_tokens,
        output_tokens=token_usage.completion_tokens,
        total_time=getattr(token_usage, "total_time", 0),
        model_name=MODEL_NAME,
    )

    return stats, response.choices[0].message.content

def _read_field(payload, name, default=0):
    """Read *name* from a payload that may be an attribute object or a plain dict."""
    if isinstance(payload, dict):
        return payload.get(name, default)
    return getattr(payload, name, default)


def generate_section(prompt: str):
    """
    Stream the generated text for one section.

    Yields:
        str: text fragments as they arrive from the streaming completion.
        GenerationStatistics: emitted when a chunk carries an ``x_groq``
            payload with usage data; chunks without it yield text only.
    """
    stream = st.session_state.client.chat.completions.create(
        model=MODEL_NAME,
        messages=[
            {
                "role": "system",
                "content": "You are an expert writer. Generate a long, comprehensive, structured chapter for the section provided."
            },
            {
                "role": "user",
                "content": f"Generate a long, comprehensive, structured chapter for the following section:\n\n<section_title>{prompt}</section_title>"
            }
        ],
        temperature=0.3,
        max_tokens=8000,
        top_p=1,
        stream=True,
        stop=None,
    )

    for chunk in stream:
        # A chunk can arrive with an empty choices list; indexing it blindly
        # would raise IndexError and abort the whole section.
        if chunk.choices:
            tokens = chunk.choices[0].delta.content
            if tokens:
                yield tokens

        x_groq = getattr(chunk, "x_groq", None)
        if not x_groq:
            continue

        # The extension payload may be surfaced as an object or as a raw dict,
        # so every field is read through _read_field.
        usage = _read_field(x_groq, "usage", None)
        if not usage:
            continue

        yield GenerationStatistics(
            input_time=_read_field(usage, "prompt_time"),
            output_time=_read_field(usage, "completion_time"),
            input_tokens=_read_field(usage, "prompt_tokens"),
            output_tokens=_read_field(usage, "completion_tokens"),
            total_time=_read_field(usage, "total_time"),
            model_name=MODEL_NAME,
        )

# Initialize: give each session-state key its default unless it is already set.
for _state_key, _state_default in (
    ('button_disabled', False),
    ('button_text', "创建"),
    ('statistics_text', ""),
    ('book_title', ""),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# Page heading.
st.write("\n# 用开源模型生成一本书\n")

def disable():
    """Set the ``button_disabled`` session flag to True."""
    st.session_state["button_disabled"] = True

def enable():
    """Set the ``button_disabled`` session flag to False."""
    st.session_state["button_disabled"] = False

def empty_st():
    """Insert an empty Streamlit placeholder; the returned handle is discarded."""
    st.empty()

try:
    # Status text shown while generation is running. Kept in one place so
    # display_statistics() can tell it apart from a statistics table.
    generating_text = "🚀🚀🚀 正在努力创建相关内容 ...."

    if st.button('创建并下载电子书'):
        if "book" not in st.session_state:
            raise ValueError("请先创建书籍 然后下载")

        # Render the markdown once and reuse it for both download formats.
        book_markdown = st.session_state.book.get_markdown_content()

        # Create markdown file
        st.download_button(
            label='下载文本文件',
            data=create_markdown_file(book_markdown),
            file_name=f'{st.session_state.book_title}.txt',
            mime='text/plain'
        )

        # Create pdf file (styled)
        st.download_button(
            label='下载PDF文件',
            data=create_pdf_file(book_markdown),
            file_name=f'{st.session_state.book_title}.pdf',
            mime='application/pdf'
        )

    with st.form("groqform"):
        # The key field is only shown when no key came from the environment.
        if not API_KEY:
            groq_input_key = st.text_input("Enter your API Key (gsk_yA...):", "",type="password")

        topic_text = st.text_input("你想生成哪方面的电子书?", "")

        # Generate button
        submitted = st.form_submit_button(st.session_state.button_text,on_click=disable,disabled=st.session_state.button_disabled)

        # Statistics
        placeholder = st.empty()
        def display_statistics():
            """Render the current status/statistics text into the placeholder."""
            with placeholder.container():
                status = st.session_state.statistics_text
                if not status:
                    placeholder.empty()
                elif generating_text in status:
                    # Progress message: show as-is, without a divider.
                    st.markdown(status)
                else:
                    st.markdown(status+"\n\n---\n") # Format with line if showing statistics

        if submitted:
            if len(topic_text) < 10:
                raise ValueError("Book topic must be at least 10 characters long")

            st.session_state.button_disabled = True
            st.session_state.statistics_text = generating_text
            display_statistics()

            if not API_KEY:
                # No environment key: the client must use the key typed into
                # the form, not the (empty) module-level API_KEY.
                if not groq_input_key:
                    raise ValueError("An API key is required to generate a book")
                st.session_state.client = Client(api_key=groq_input_key)

            # Generate AI book title
            st.session_state.book_title = generate_book_title(topic_text)
            st.write(f"## {st.session_state.book_title}")

            # The structure-call statistics are returned but not folded into
            # the totals displayed below.
            large_model_generation_statistics, book_structure = generate_book_structure(topic_text)

            total_generation_statistics = GenerationStatistics(model_name=MODEL_NAME)

            try:
                # Repair truncated/malformed JSON before parsing it.
                book_structure = json_fixer.fix(book_structure).line
                book_structure_json = json.loads(book_structure)
                book = Book(st.session_state.book_title, book_structure_json)

                # Always replace the stored book: keeping a previous run's book
                # would stream new sections into titles it does not know.
                st.session_state.book = book

                # Print the book structure to the terminal to show structure
                print(json.dumps(book_structure_json, indent=2))

                st.session_state.book.display_structure()

                def stream_section_content(sections):
                    """Generate text for every leaf section, recursing into nested ones."""
                    for title, content in sections.items():
                        if isinstance(content, str):
                            for chunk in generate_section(title+": "+content):
                                # The stream yields either text or a statistics object.
                                if isinstance(chunk, GenerationStatistics):
                                    total_generation_statistics.add(chunk)

                                    st.session_state.statistics_text = str(total_generation_statistics)
                                    display_statistics()

                                elif chunk is not None:
                                    st.session_state.book.update_content(title, chunk)
                        elif isinstance(content, dict):
                            stream_section_content(content)

                stream_section_content(book_structure_json)

            except json.JSONDecodeError:
                st.error("Failed to decode the book structure. Please try again.")
                print(book_structure)

            enable()

except Exception as e:
    # Top-level boundary: re-enable the button and surface the error in the UI.
    st.session_state.button_disabled = False
    st.error(e)

    if st.button("清空内容"):
        st.rerun()